{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The basal ganglia\n",
    "\n",
    "The basal ganglia,\n",
    "according to [Stewart 2010](\n",
    "http://compneuro.uwaterloo.ca/files/publications/stewart.2010.pdf),\n",
    "is an action selector\n",
    "that chooses whichever action has the highest \"salience\" or \"goodness\".\n",
    "Its more interesting behaviour emerges\n",
    "when it interacts with the thalamus and other components of the brain,\n",
    "but in this example we will only show the basal ganglia's basic behaviour.\n",
    "It will choose between three actions\n",
    "that we'll pretend are \"eating\", \"sleeping\" and \"playing\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import nengo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Create the Network\n",
    "\n",
    "Here we create the basal ganglia and the action input node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network(label=\"Basal Ganglia\")\n",
    "with model:\n",
    "    # Three dimensions, one per candidate action\n",
    "    basal_ganglia = nengo.networks.BasalGanglia(dimensions=3)\n",
    "\n",
    "\n",
    "class ActionIterator:\n",
    "    \"\"\"Cycles through the actions, making one dominant each second.\"\"\"\n",
    "\n",
    "    def __init__(self, dimensions):\n",
    "        self.actions = np.ones(dimensions) * 0.1\n",
    "\n",
    "    def step(self, t):\n",
    "        # one action at a time dominates; the winner changes every second\n",
    "        dominate = int(t % len(self.actions))\n",
    "        self.actions[:] = 0.1\n",
    "        self.actions[dominate] = 0.8\n",
    "        return self.actions\n",
    "\n",
    "\n",
    "action_iterator = ActionIterator(dimensions=3)\n",
    "\n",
    "with model:\n",
    "    actions = nengo.Node(action_iterator.step, label=\"actions\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Connect the Network\n",
    "\n",
    "Connect the input node to the basal ganglia,\n",
    "and add probes to record both the input and the basal ganglia's output."
   ]
  },
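  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    # Feed the saliences straight into the basal ganglia, with no synaptic filter\n",
    "    nengo.Connection(actions, basal_ganglia.input, synapse=None)\n",
    "    # Record output and input, each smoothed by a 10 ms lowpass filter\n",
    "    selected_action = nengo.Probe(basal_ganglia.output, synapse=0.01)\n",
    "    input_actions = nengo.Probe(actions, synapse=0.01)"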
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Simulate the Network and Plot the Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    # This will take a while\n",
    "    sim.run(6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[input_actions].argmax(axis=1))\n",
    "plt.ylim(-0.1, 2.1)\n",
    "plt.xlabel(\"time [s]\")\n",
    "plt.title(\"Index of actual max value\")\n",
    "plt.subplot(2, 1, 2)\n",
    "plt.plot(sim.trange(), sim.data[selected_action].argmax(axis=1))\n",
    "plt.ylim(-0.1, 2.1)\n",
    "plt.xlabel(\"time [s]\")\n",
    "plt.title(\"Basal ganglia selected max value\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the index of the maximum value\n",
    "is 0, then 1, then 2,\n",
    "corresponding to \"eating\", \"sleeping\", then \"playing\",\n",
    "after which the sequence repeats.\n",
    "Note that if you zoom in enough on the basal ganglia values,\n",
    "you'll be able to see a short delay\n",
    "before the selection switches to the new maximum.\n",
    "If you read the aforementioned paper,\n",
    "you'll see that this is expected and matches previous experiments."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
